/*--------------------------------------------------------------------------------------------------
| This file is distributed under the MIT License.
| See accompanying file /LICENSE for details.
| Author(s): Bruno Schmitt
*-------------------------------------------------------------------------------------------------*/
#pragma once

#include "../gates/gate_lib.hpp"
#include "../views/depth_view.hpp"

#include <algorithm>
#include <cstdint>
#include <fmt/format.h>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

namespace tweedledum {

/*! \brief Writes network in DOT format into output stream
 *
 * An overloaded variant exists that writes the network into a file.
 *
 * **Required gate functions:**
 * - `is`
 * - `is_control`
 * - `foreach_target`
 *
 * **Required network functions:**
 * - `get_node`
 * - `clear_marks`
 * - `mark`
 * - `foreach_child`
 * - `foreach_gate`
 * - `foreach_input`
 * - `foreach_output`
 *
 * \param network Network
 * \param os Output stream
 */
template <typename Network>
void write_dot(Network const& dag_network, std::ostream& os = std::cout) {
    depth_view network_view(dag_network);

    os << "# Quantum DAG structure generated by tweedledum package\n";
    os << "digraph QuantumNet {\n";
    os << "\trankdir = \"RL\";\n";
    os << "\tsize = \"7.5,10\";\n";
    os << "\tcenter = true;\n";
    os << "\tedge [dir = back];\n\n";

    os << "\t{\n";
    os << "\t\tnode [shape = plaintext];\n";
    os << "\t\tedge [style = invis];\n";
    for (auto level = network_view.depth() + 1; level-- > 0;) {
        os << fmt::format("\t\tLevel{} [label = \"\"];\n", level);
    }
    os << "\t\t";
    for (auto level = network_view.depth() + 1; level-- > 0;) {
        os << fmt::format("Level{}{}", level, level != 0 ? " -> " : ";");
    }
    os << "\n\t}\n";

    os << "\t{\n";
    os << "\t\trank = same;\n";
    os << "\t\tnode[shape = cds, style = filled];\n";
    os << fmt::format("\t\tLevel{};\n", network_view.depth());
    network_view.foreach_output([&](auto const& node, auto index) {
        io_id output_id = node.gate.target();
        os << fmt::format(
            "\t\tNode{} [label = \"{}\", xlabel = \"{}\", fillcolor = {}];\n",
            index, network_view.io_label(output_id), index,
            output_id.is_qubit() ? "lightcoral" : "lightblue");
    });
    os << "\t}\n";

    for (auto level = network_view.depth(); level-- > 1;) {
        os << "\t{\n";
        os << "\t\trank = same;\n";
        os << fmt::format("\t\tLevel{};\n", level);
        network_view.foreach_gate([&network_view, &os, level](auto const& node,
                                                              auto index) {
            if (network_view.level(node) != level) {
                return;
            }
            os << fmt::format(
                "\t\tNode{} [label = \"{}\", xlabel = \"{}\", shape = {}];\n",
                index, node.gate.symbol(), index,
                node.gate.is(gate_lib::cx) ? "doublecircle" : "circle");
        });
        os << "\t}\n";
    }

    os << "\t{\n";
    os << "\t\trank = same;\n";
    os << "\t\tnode[shape = cds, style = filled];\n";
    os << "\t\tLevel0;\n";
    network_view.foreach_input([&](auto const& node, auto index) {
        io_id input_id = node.gate.target();
        os << fmt::format(
            "\t\tNode{} [label = \"{}\", xlabel = \"{}\", fillcolor = {}];\n",
            index, network_view.io_label(input_id), index,
            input_id.is_qubit() ? "lightcoral" : "lightblue");
    });
    os << "\t}\n\n";

    network_view.foreach_gate(
        [&network_view, &os](auto const& node, auto index) {
            network_view.foreach_child(node, [&](auto child) {
                os << fmt::format("\tNode{} -> Node{} [style = solid];\n",
                                  index, child.index);
            });
        });
    network_view.foreach_output(
        [&network_view, &os](auto const& node, auto index) {
            network_view.foreach_child(node, [&](auto child) {
                os << fmt::format("\tNode{} -> Node{} [style = solid];\n",
                                  index, child.index);
            });
        });
    os << "}\n";
}

/*! \brief Writes network in DOT format into a file
 *
 * Opens `filename` for writing and forwards to the output-stream overload of
 * `write_dot`, so the requirements on `Network` are those of that overload.
 *
 * If the file cannot be opened, a message is printed to `std::cerr` and the
 * function returns without writing anything.
 *
 * \param network Network
 * \param filename Filename
 */
template <typename Network>
void write_dot(Network const& network, std::string const& filename) {
    std::ofstream os(filename);
    if (!os.is_open()) {
        // Report the failure instead of silently writing into a dead stream.
        std::cerr << "[e] write_dot: could not open '" << filename
                  << "' for writing\n";
        return;
    }
    write_dot(network, os);
}

} // namespace tweedledum
